{
 "cells": [
  {
   "cell_type": "markdown",
   "source": [
    "# Custom Operations"
   ],
   "metadata": {
    "collapsed": false
   },
   "id": "36a6807951adb09e"
  },
  {
   "cell_type": "markdown",
   "source": [
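    "The library's core design lets us define custom operations at several levels: by overriding `__call__`, by declaring an operation without an implementation, or by registering a runtime-specific implementation."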
   ],
   "metadata": {
    "collapsed": false
   },
   "id": "dac1cf09c8d4894a"
  },
  {
   "cell_type": "code",
   "outputs": [],
   "source": [
    "from traceback import print_exc\n",
    "\n",
    "import numpy as np\n",
    "\n",
    "from braandket import ArrayLike, PureStateTensor, pi\n",
    "from braandket_circuit import BnkParticle, BnkRuntime, BnkState, H, M, QOperation, QParticle, QSystemStruct, Ry, Rz, X, \\\n",
    "    allocate_qubits, register_apply_impl\n",
    "from braandket_circuit.utils import iter_struct"
   ],
   "metadata": {
    "collapsed": false,
    "ExecuteTime": {
     "end_time": "2024-01-25T08:33:44.952823800Z",
     "start_time": "2024-01-25T08:33:41.376929400Z"
    }
   },
   "id": "736087da4a9c0b79",
   "execution_count": 1
  },
  {
   "cell_type": "markdown",
   "source": [
    "## 1. Overriding the `__call__` method\n",
    "\n",
    "As we have seen in the previous sections, we can subclass the `QOperation` class and override the `__call__` method to define our custom operations."
   ],
   "metadata": {
    "collapsed": false
   },
   "id": "cec46e77077dee93"
  },
  {
   "cell_type": "markdown",
   "source": [
    "Here is an example of a custom parameterized operation `Uzyz` as in [Parameterized Circuit](example4_parameterized_circuit.ipynb)."
   ],
   "metadata": {
    "collapsed": false
   },
   "id": "be8574eb4060498e"
  },
  {
   "cell_type": "code",
   "outputs": [],
   "source": [
    "class Uzyz(QOperation):\n",
    "    def __init__(self, thetas: ArrayLike):\n",
    "        super().__init__()\n",
    "        self.thetas = thetas\n",
    "\n",
    "    def __call__(self, qubit: QParticle):\n",
    "        Rz(self.thetas[0])(qubit)\n",
    "        Ry(self.thetas[1])(qubit)\n",
    "        Rz(self.thetas[2])(qubit)"
   ],
   "metadata": {
    "collapsed": false,
    "ExecuteTime": {
     "end_time": "2024-01-25T08:33:44.960460900Z",
     "start_time": "2024-01-25T08:33:44.955698200Z"
    }
   },
   "id": "8da690a9f54d91e1",
   "execution_count": 2
  },
  {
   "cell_type": "markdown",
   "source": [
    "Here is another example of a custom operation, `FourierSampling`, which applies `H` to every given qubit."
   ],
   "metadata": {
    "collapsed": false
   },
   "id": "de0d489995391fb"
  },
  {
   "cell_type": "code",
   "outputs": [],
   "source": [
    "class FourierSampling(QOperation):\n",
    "    def __call__(self, *qubits: QParticle):\n",
    "        for qubit in qubits:\n",
    "            H(qubit)"
   ],
   "metadata": {
    "collapsed": false,
    "ExecuteTime": {
     "end_time": "2024-01-25T08:33:44.961470100Z",
     "start_time": "2024-01-25T08:33:44.959239400Z"
    }
   },
   "id": "1cbf445b0d838718",
   "execution_count": 3
  },
  {
   "cell_type": "markdown",
   "source": [
    "## 2. Without overriding the `__call__` method\n",
    "\n",
    "It is worth noting that overriding the `__call__` method is optional.\n"
   ],
   "metadata": {
    "collapsed": false
   },
   "id": "3f7ffedd86def525"
  },
  {
   "cell_type": "markdown",
   "source": [
    "For example, suppose we want to define an operation $P$ that applies a phase shift only to the last diagonal element.\n",
    "\n",
    "$$\n",
    "P(\\theta)=\\begin{pmatrix} 1 & & & \\\\ & \\ddots & & \\\\ & & 1 & \\\\ & & & e^{i\\theta} \\end{pmatrix}\n",
    "$$\n",
    "\n",
    "If it is not yet clear how to implement it with the existing operations, we can define it without overriding the `__call__` method."
   ],
   "metadata": {
    "collapsed": false
   },
   "id": "63efa1cbefd7bd08"
  },
  {
   "cell_type": "code",
   "outputs": [],
   "source": [
    "class LastDiagonalPhase(QOperation):\n",
    "    def __init__(self, theta: ArrayLike):\n",
    "        super().__init__()\n",
    "        self.theta = theta"
   ],
   "metadata": {
    "collapsed": false,
    "ExecuteTime": {
     "end_time": "2024-01-25T08:33:44.970101300Z",
     "start_time": "2024-01-25T08:33:44.963470300Z"
    }
   },
   "id": "bce3458f108211a",
   "execution_count": 4
  },
  {
   "cell_type": "markdown",
   "source": [
    "Since we have given no information about how `LastDiagonalPhase` works, calling it raises a `NotImplementedError`. Although such an operation cannot actually be called, it can still be used in other processes such as visualization and compilation."
   ],
   "metadata": {
    "collapsed": false
   },
   "id": "d80618270353915b"
  },
  {
   "cell_type": "code",
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Traceback (most recent call last):\n",
      "  File \"C:\\Users\\keli\\AppData\\Local\\Temp\\ipykernel_8276\\4003537457.py\", line 4, in <module>\n",
      "    LastDiagonalPhase(pi)(q0, q1)\n",
      "  File \"C:\\Users\\keli\\Develop\\BraAndKet\\braandket-circuit\\braandket_circuit\\basics\\operation.py\", line 27, in __call__\n",
      "    return apply(get_current_runtime(), self, *args)\n",
      "           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n",
      "  File \"C:\\Users\\keli\\Develop\\BraAndKet\\braandket-circuit\\braandket_circuit\\traits\\apply\\apply.py\", line 75, in apply\n",
      "    raise impls_error[0]\n",
      "  File \"C:\\Users\\keli\\Develop\\BraAndKet\\braandket-circuit\\braandket_circuit\\traits\\apply\\apply.py\", line 69, in apply\n",
      "    return impl(rt, op, *args)\n",
      "           ^^^^^^^^^^^^^^^^^^^\n",
      "  File \"C:\\Users\\keli\\Develop\\BraAndKet\\braandket-circuit\\braandket_circuit\\traits_impls\\apply\\impls.py\", line 28, in default_impl\n",
      "    raise NotImplementedError\n",
      "NotImplementedError\n"
     ]
    }
   ],
   "source": [
    "q0, q1 = allocate_qubits(2)\n",
    "\n",
    "try:\n",
    "    LastDiagonalPhase(pi)(q0, q1)\n",
    "except NotImplementedError:\n",
    "    print_exc()"
   ],
   "metadata": {
    "collapsed": false,
    "ExecuteTime": {
     "end_time": "2024-01-25T08:33:44.971614900Z",
     "start_time": "2024-01-25T08:33:44.969633100Z"
    }
   },
   "id": "5103999d0c1c3b20",
   "execution_count": 5
  },
  {
   "cell_type": "markdown",
   "source": [
    "## 3. Registering runtime-specific implementation\n",
    "\n",
    "In some cases we do not want to describe the custom operation in terms of existing operations, but we know exactly how it works and want it to be callable. We can then register a runtime-specific implementation via `register_apply_impl`.\n"
   ],
   "metadata": {
    "collapsed": false
   },
   "id": "15dbcfa6a1615980"
  },
  {
   "cell_type": "markdown",
   "source": [
    "Taking `LastDiagonalPhase` as an example, we can implement it for a simulator based on `BnkRuntime`. Here we only implement the calculation for pure states with NumPy values; other cases raise `NotImplementedError`."
   ],
   "metadata": {
    "collapsed": false
   },
   "id": "627411b2f418f288"
  },
  {
   "cell_type": "code",
   "outputs": [],
   "source": [
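    "@register_apply_impl(BnkRuntime, LastDiagonalPhase)\n",
    "def last_diagonal_phase_impl(rt: BnkRuntime, op: LastDiagonalPhase, *args: QSystemStruct):\n",
    "    # Collect every particle in the arguments and combine their states into one product state.\n",
    "    particles = tuple(particle for particle in iter_struct(args, atom_typ=BnkParticle))\n",
    "    state = BnkState.prod(*(particle.state for particle in particles))\n",
    "    # Only pure states are supported by this implementation.\n",
    "    if not isinstance(state.tensor, PureStateTensor):\n",
    "        raise NotImplementedError\n",
    "\n",
    "    # Only NumPy-backed values are supported by this implementation.\n",
    "    values = state.tensor.values()\n",
    "    if not isinstance(values, np.ndarray):\n",
    "        raise NotImplementedError\n",
    "\n",
    "    # Remember the spaces and shape, then flatten the amplitudes into a 1-D vector.\n",
    "    spaces = state.tensor.ket_spaces\n",
    "    shape = values.shape\n",
    "    values = rt.backend.reshape(values, [-1])\n",
    "\n",
    "    # Multiply only the last amplitude by exp(i * theta); all others are left unchanged.\n",
    "    phase = np.ones_like(values, dtype=np.complex128)\n",
    "    phase[-1] *= np.exp(1j * op.theta)\n",
    "    values = values * phase\n",
    "\n",
    "    # Restore the original shape and write the result back into the state.\n",
    "    values = rt.backend.reshape(values, shape)\n",
    "    state.tensor = PureStateTensor.of(values, spaces, backend=rt.backend)"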
   ],
   "metadata": {
    "collapsed": false,
    "ExecuteTime": {
     "end_time": "2024-01-25T08:33:45.021859900Z",
     "start_time": "2024-01-25T08:33:44.976620500Z"
    }
   },
   "id": "431a759bb6baa422",
   "execution_count": 6
  },
  {
   "cell_type": "markdown",
   "source": [
    "After registering the runtime-specific implementation, we can now call the `LastDiagonalPhase` operation."
   ],
   "metadata": {
    "collapsed": false
   },
   "id": "55fb096df79ae025"
  },
  {
   "cell_type": "code",
   "outputs": [],
   "source": [
    "LastDiagonalPhase(pi)(q0, q1)"
   ],
   "metadata": {
    "collapsed": false,
    "ExecuteTime": {
     "end_time": "2024-01-25T08:33:45.021859900Z",
     "start_time": "2024-01-25T08:33:44.978627500Z"
    }
   },
   "id": "160cec4e1dbc9611",
   "execution_count": 7
  },
  {
   "cell_type": "markdown",
   "source": [
    "Then we can verify the effect of the `LastDiagonalPhase` operation."
   ],
   "metadata": {
    "collapsed": false
   },
   "id": "a18ec32da7ae3ff1"
  },
  {
   "cell_type": "code",
   "outputs": [
    {
     "data": {
      "text/plain": "MeasurementResult(target=<BnkParticle name=None>, value=1, prob=array(1.))"
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "q0, q1 = allocate_qubits(2)\n",
    "X(q0)\n",
    "\n",
    "# LastDiagonalPhase(pi) on two qubits acts as CZ, so H-CZ-H on q1 is equivalent to CNOT(q0, q1)\n",
    "H(q1)\n",
    "LastDiagonalPhase(pi)(q0, q1)\n",
    "H(q1)\n",
    "\n",
    "M(q1)"
   ],
   "metadata": {
    "collapsed": false,
    "ExecuteTime": {
     "end_time": "2024-01-25T08:33:45.022860Z",
     "start_time": "2024-01-25T08:33:44.983375500Z"
    }
   },
   "id": "d4665c539c090524",
   "execution_count": 8
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 2
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython2",
   "version": "2.7.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
